package com.open.server;

import cn.hutool.core.codec.Base64Encoder;
import cn.hutool.core.io.IoUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.core.util.URLUtil;
import cn.hutool.http.HttpRequest;
import cn.hutool.http.HttpResponse;
import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONUtil;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.open.bean.IpWhite;
import com.open.bean.OpenSystemRole;
import com.open.bean.RequestRecord;
import com.open.bean.RequestToken;
import com.open.bean.vo.MessageVo;
import com.open.component.JwtComponent;
import com.open.service.OpenSystemRoleService;
import com.open.service.RequestRecordService;
import com.open.util.*;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;

import javax.websocket.*;
import javax.websocket.server.ServerEndpoint;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * @author tanyongpeng
 * <p>发消息服务，最新的</p>
 **/
@Component
@ServerEndpoint("/webSocket/2023/4")
@Slf4j
public class MsgSocketServer {

    private Session session;

    private String userId;

    private static CopyOnWriteArrayList<MsgSocketServer> webSocketServers = new CopyOnWriteArrayList<>();

    @OnOpen
    public void onOpen(Session session) throws IOException {
        this.session = session;
        this.userId = StrUtil.isBlank(getUserId()) ? "" : getUserId();

        // 查找是否已经有与该用户ID关联的连接
        Optional<MsgSocketServer> existingConnection = webSocketServers.stream()
                .filter(server -> server.userId.equals(this.userId))
                .findFirst();

        // 如果存在旧连接，则关闭旧连接并从列表中移除
        if (existingConnection.isPresent()) {
            existingConnection.get().session.close();
            webSocketServers.remove(existingConnection.get());
        }

        // 将新连接添加到列表中
        webSocketServers.add(this);

        log.info("当前用户id==={}", this.userId);
        log.info("新的连接，数量：{}", webSocketServers.size());
    }


    @OnClose
    public void onClose(){
        webSocketServers.remove(this);
        log.info("新的连接，数量：{}",webSocketServers.size());
    }

    @OnMessage
    public void onMessage(String msg){
        if (StrUtil.isNotBlank(this.userId)){
            if (msg.equals("ping")) { // 添加处理心跳信号的逻辑
                try {
                    this.session.getBasicRemote().sendText("pong");
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }else {
                log.info("来自客户端的消息：" + msg);
            }
        }
        log.info("心跳检测....用户id：{}",this.userId);
    }

    public void sendMessage(String msg) {
        for (MsgSocketServer webSocketServer : webSocketServers) {
            try {
                webSocketServer.session.getBasicRemote().sendText(msg);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    public void sendMessage(String msg, String userId) {
        for (MsgSocketServer webSocketServer : webSocketServers) {
            if (webSocketServer.userId.equals(userId)) {
                try {
                    if (webSocketServer.session.isOpen()) {
                        webSocketServer.session.getBasicRemote().sendText(msg);
                        log.info("成功========== user: {}", userId);
                    } else {
                        log.error("WebSocket connection is closed for user: {}", userId);
                    }
                } catch (IOException e) {
                    e.printStackTrace();
                }
                break;
            }
        }
    }

    public CompletableFuture<Void> msg(String wId, String code, String content, RequestRecordService requestRecordService, List<RequestToken> requestTokenList,
                                       Executor executor, String system, OpenSystemRoleService openSystemRoleService, String modelType,
                                       IpWhite ipWhite,String imgTag,String fileText){
        return CompletableFuture.runAsync(() -> {
            long start = System.currentTimeMillis();
            RedisUtil.set(wId+"sendMsg", 1, 120);
            List<RequestRecord> requestRecordList = requestRecordService.lambdaQuery()
                    .eq(RequestRecord::getWid, wId.split("_")[0])
                    .eq(RequestRecord::getIsDelete, 0)
                    .eq(RequestRecord::getGroupCode,code)
                    .orderByDesc(RequestRecord::getRequestId)
                    .last("limit "+ipWhite.getDialogueCount()).list();
            Collections.reverse(requestRecordList);
            List<MessageVo> messageVoList = new ArrayList<>();
            if (modelType.equals("gpt-3.5-turbo-identify")) {
                if (StrUtil.isNotBlank(imgTag)){
                    messageVoList.add(new MessageVo("system","你现在是一台图片识别机器，我将会告诉你图片内容场景，图片内容为："+imgTag+" ，我会根据图片内容提问"));
                }
            }else {
                if (StrUtil.isNotBlank(system) && !system.equals("null")){
                    OpenSystemRole role = openSystemRoleService.getById(system);
                    messageVoList.add(new MessageVo("system",role.getRoleValue()));
                }
            }
            if (StrUtil.isNotBlank(fileText)){
                messageVoList.add(new MessageVo("system",fileText));
            }
            String newModel = modelType.replace("-identify","");
            for (RequestRecord requestRecord : requestRecordList) {
                messageVoList.add(new MessageVo("user",requestRecord.getRequestContent()));
                messageVoList.add(new MessageVo("assistant",requestRecord.getRespondContent()));
            }
            messageVoList.add(new MessageVo("user", URLUtil.decode(content, Charset.defaultCharset())));
            log.info("message对象================{}", JSONUtil.parseArray(messageVoList).toString());
            if (SystemQueryUtil.isWindows()){
                for (String text : LocalhostTextUtil.text()) {
                    try {
                        Thread.sleep(50);
                        sendMessage(text, wId);
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
                RedisUtil.set(wId+"sendMsg",0,120);
            }else {
                HttpRequest bodyRequest = HttpRequest.get("https://api.openai.com/v1/chat/completions")
                        .header("Content-Type", "application/json")
                        .header("Authorization", "Bearer "+requestTokenList.get(RandomUtil.getRandom().nextInt(requestTokenList.size())).getSign())
                        .header("Accept", "text/event-stream")
                        .body("{\n" +
                                "    \"messages\": "+JSONUtil.parseArray(messageVoList).toString()+",\n" +
                                "    \"model\": \""+newModel+"\",\n" +
                                "    \"temperature\": "+ipWhite.getTemperature()+",\n" +
                                "    \"top_p\": 1,\n" +
                                "    \"n\": 1,\n" +
                                "    \"stream\":true\n" +
                                "}");
                HttpResponse responseBody = bodyRequest.executeAsync();
                if (responseBody.isOk()) {
                    try (InputStream inputStream = responseBody.bodyStream();
                         InputStreamReader inputStreamReader = new InputStreamReader(inputStream);
                         BufferedReader bufferedReader = new BufferedReader(inputStreamReader)) {
                        String line;
                        while ((line = bufferedReader.readLine()) != null) {
                            log.info("消息===={}",line);
                            if (line.startsWith("data:")) {
                                sendMessage(line, wId);
                            }
                        }
                        RedisUtil.set(wId+"sendMsg",0,120);
                    } catch (IOException e) {
                        RedisUtil.set(wId+"sendMsg",0,120);
                    } finally {
                        RedisUtil.set(wId+"sendMsg",0,120);
                        IoUtil.close(responseBody);
                    }
                } else {
                    RedisUtil.set(wId+"sendMsg",0,120);
                }
            }       
            long end = System.currentTimeMillis();
            log.info("当前请求消耗时间：{}ms",end-start);
        }, executor);
    }

    public String getUserId(){
        try {
            String token = session.getRequestParameterMap().get("token").get(0);
            DecodedJWT tokenInfo = JwtComponent.getTokenStaticInfo(token);
            String browserId = session.getRequestParameterMap().get("browserId").get(0);
            return tokenInfo.getClaim("wid").asString()+"_"+browserId;
        }catch (Exception e){
            log.info("老旧的用户连接着。。。token跟新的算法不一致导致了");
            return null;
        }
    }


}
